import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    """Two-layer graph attention network (GAT) producing per-node log-probabilities.

    The first ``GATConv`` layer maps input node features to ``n_hidden``
    channels per attention head (head outputs are concatenated), followed by
    ReLU and feature dropout. The second ``GATConv`` layer maps the hidden
    representation to ``n_classes`` channels, which are turned into
    log-probabilities with ``log_softmax``.

    Args:
        n_features: Number of input features per node.
        n_hidden: Number of hidden channels per attention head in the first layer.
        n_classes: Number of output channels (classes).
        dropout: Dropout probability. It is passed to both ``GATConv`` layers
            (where it is applied to the normalized attention coefficients) and
            also used as the feature dropout between the two layers.
        heads: Number of attention heads in the first layer. Because
            ``GATConv`` concatenates head outputs by default, the second layer
            receives ``n_hidden * heads`` channels. The default of 1 gives a
            single-head model.
    """

    def __init__(self, n_features, n_hidden, n_classes, dropout, heads=1):
        super().__init__()

        # First layer: `heads` attention heads, each producing n_hidden
        # channels; concat=True (GATConv default) stacks them side by side.
        self.gc1 = GATConv(
            in_channels=n_features,
            out_channels=n_hidden,
            heads=heads,
            dropout=dropout,
        )
        # Output layer: single head, so its width is exactly n_classes.
        # Its input width must match gc1's concatenated output.
        self.gc2 = GATConv(
            in_channels=n_hidden * heads,
            out_channels=n_classes,
            dropout=dropout,
        )

        self.heads = heads
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Compute class log-probabilities for every node.

        Args:
            x: Node feature matrix; assumed to be (num_nodes, n_features).
            edge_index: Graph connectivity passed straight to ``GATConv``;
                presumably PyG COO format of shape (2, num_edges).

        Returns:
            The second layer's output after ``log_softmax`` over dim 1, i.e.
            one row of log-probabilities per node (assuming 2-D output).
        """
        x = F.relu(self.gc1(x, edge_index))
        # Feature dropout between layers; a no-op when the module is in eval mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gc2(x, edge_index)
        return F.log_softmax(x, dim=1)